from __future__ import division, print_function, absolute_import

import os
import pickle

import numpy as np
from CDTools.tools.cmath import *
import torch as t
from RIP_tools import *

#
# This section runs the numerical experiment designed
# to test the robustness of RIP to leaked spectral weight
#

# probe_maxk is passed to generate_blr_probe and also sets the size of the
# random Fourier-space test images (2 * probe_maxk per side);
# resolution_ratio is the fraction of that width the retrieval window spans.
probe_maxk = 128
resolution_ratio = 0.4

# Run the retrievals on the GPU whenever torch can see one. Set this to
# False explicitly to force CPU-only reconstructions.
GPU = t.cuda.is_available()

# This defines the size of the reciprocal space array we will simulate the
# probe on, and the radius handed to generate_blr_probe (40% of that size).
big_stage = int(probe_maxk) * 4
probe_fov_radius = (big_stage * 4)//10

# Simulate the test probe on a square big_stage x big_stage reciprocal-space grid
nearfield = generate_blr_probe([big_stage] * 2, probe_fov_radius, probe_maxk)

# Side length of the array used for the retrieval: the central
# resolution_ratio fraction of the (2 * probe_maxk)-wide Fourier-space image
retrieval_size = 2 * int(resolution_ratio * probe_maxk)

# Padding trimmed from each edge before computing the error, so that the
# resulting error comes from the central half of the image along each axis
ec_padding = retrieval_size // 4


# Per-power-ratio lists of final losses and reconstruction errors
losses, errors = [], []

# These are the power ratios we will check: 70 evenly spaced values in [0, 0.5]
power_ratios = np.linspace(0, 0.5, 70)

def _centre_slice(pad):
    """Return a slice that drops ``pad`` elements from each end of an axis.

    Unlike ``slice(pad, -pad)``, this stays correct for ``pad == 0``: there
    ``-0`` would make the slice empty instead of selecting the whole axis.
    """
    return slice(pad, -pad if pad else None)


# The random test images are (2*probe_maxk) x (2*probe_maxk) in Fourier
# space; the retrieval only covers their central retrieval_size-wide window.
# These regions do not change between trials, so they are built once.
image_size = 2 * probe_maxk
padding = (image_size - retrieval_size) // 2
if padding < 0:
    raise ValueError('retrieval_size (%d) exceeds the simulated image size (%d)'
                     % (retrieval_size, image_size))
central_region = (_centre_slice(padding),) * 2

# Region of the images compared when computing the reconstruction error
error_region = (_centre_slice(ec_padding),) * 2

# For every power ratio, run many independent trials, each with a fresh
# random object, and record the final loss and reconstruction error.
for power_ratio in power_ratios:
    loop_losses = []
    loop_errors = []

    for attempt in range(200):

        # We generate a random image, defined in Fourier space, with
        # independent standard-normal real and imaginary parts
        real = np.random.randn(image_size, image_size)
        imag = np.random.randn(image_size, image_size)
        original_image = real + 1j * imag

        # Then, we calculate how much power is contained in the
        # region we will attempt to reconstruct
        central_power = np.sum(np.abs(original_image[central_region])**2)

        # And how much power is contained outside
        outer_power = np.sum(np.abs(original_image)**2) - central_power

        # Finally, we scale the whole image by
        # sqrt(power_ratio * central_power / outer_power), which makes the
        # outer power equal power_ratio * central_power, and then replace
        # the central region with the unscaled version. When the window
        # covers the whole image there is nothing outside to scale.
        if outer_power > 0:
            central_cutout = np.copy(original_image[central_region])
            original_image *= np.sqrt(power_ratio / outer_power * central_power)
            original_image[central_region] = central_cutout

        # This is an optional check that we actually have implemented
        # the correct power ratio
        # oi = original_image
        # central_power = np.sum(np.abs(oi[central_region])**2)
        # outer_power = np.sum(np.abs(oi)**2) - central_power
        # print('power ratio', power_ratio,  outer_power / central_power)

        # Now we find the real space representation of the image, as a
        # float32 torch tensor
        original_image = backpropagate(complex_to_torch(original_image)).to(dtype=t.float32)

        # We simulate and normalize diffraction intensities so that they
        # sum to 10 * 512**2.
        # NOTE(review): 512 equals big_stage for the default probe_maxk;
        # confirm whether this constant is meant to track big_stage.
        simulation_nearfield = expand_probe(nearfield, original_image)
        exit_wave = interact(original_image, simulation_nearfield)
        diffraction = propagate(exit_wave)
        intensities = measure(diffraction, nearfield.shape)
        intensities /= t.sum(intensities)
        intensities *= 10 * 512**2

        # We then attempt a phase retrieval (1000 passed as the 5th argument).
        # NOTE(review): the 4th argument 0.4 equals resolution_ratio; confirm
        # against reconstruct's signature whether it should track it.
        retrieval, loss = reconstruct(intensities, nearfield, retrieval_size, 0.4, 1000, GPU=GPU, schedule=True)
        ret = torch_to_complex(retrieval)
        im = torch_to_complex(original_image)

        # We calculate the RMS error over the central region only
        error, originals, comparisons = compare_images(ret, im, error_region)

        # And save out the final loss and the error
        loop_losses.append(loss[-1])
        loop_errors.append(error)

        print('Power Ratio','%.2f' % power_ratio,'Had Loss', loss[-1], 'And Error', loop_errors[-1])

    losses.append(loop_losses)
    errors.append(loop_errors)
    
    

# Collect everything needed to analyse the trial afterwards; losses[i] and
# errors[i] hold the per-attempt results for power_ratios[i]
results = {'probe': nearfield, 'losses':losses, 'power_ratios': power_ratios, 'errors':errors}

# The sweep takes a long time, so make sure the output directory exists
# instead of losing every result to an IOError at the very last step
output_path = 'data/bandlimiting robustness trial.pickle'
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.isdir(output_dir):
    os.makedirs(output_dir)
with open(output_path, 'wb') as f:
    pickle.dump(results, f)
